#include <pthread.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <netinet/in.h>
#include <binary_protocol.h>
#include "general.h"
#include "util.h"
#include "mc.h"

#include <event2/event.h>
#include <event2/listener.h>
#include <event2/buffer.h>
#include <event2/bufferevent.h>
#include <event2/thread.h>
#include <assert.h>
#include <general.h>

/* Accept callback for the listening socket (registered in main). */
void do_accept(evutil_socket_t listener, short event, void *arg);

/* bufferevent callbacks. socket_read_cb is declared, but its definition
 * further down is commented out and it is not registered anywhere. */
void socket_read_cb(struct bufferevent *bev, void *arg);
void socket_event_cb(struct bufferevent *bev, short events, void *arg);
void proxy_r_server_cb(struct bufferevent *bev, void *arg);
void proxy_r_cli_cb(struct bufferevent *bev, void *arg);

/* One bufferevent per backend server connection; bev_s[i] wraps
 * pcs_list.pcs_list[i]->conn.sfd (filled in main). */
struct bufferevent **bev_s;
/* One bufferevent per accepted client connection, stored at index
 * bf_map.num_entries at accept time (see do_accept). main sizes it as
 * num_clients * num_conn_per_client. */
struct bufferevent **bev_c;
/* Thread handles; only referenced by the commented-out threaded code, not by
 * the live event-loop code in this file. */
pthread_t **_workers, _time_thread;

/*
void *send_thread(void *arg) {
    proxy_conn_t *conn = (proxy_conn_t *)arg;
    proxy_conn_info **lp = conn->conn_list;
    int rlen, slen;
    while (1) {
        fprintf(stderr, "enter send_thread\n");
        for (int i=0; i<conn->num_conn; i++) {
            fprintf(stderr, "conn info server=%s port=%d rank=%d sfd=%d\n",
                    lp[i]->conn.server, lp[i]->conn.port, lp[i]->conn.rank, lp[i]->conn.sfd);
            rlen = recv_request(conn->proxy_fd, lp[i]->sndbuf, 24);
            if (24 == rlen) {
                binary_header_t *h = (binary_header_t *)lp[i]->rcvbuf;
                rlen = recv_request(conn->proxy_fd, lp[i]->sndbuf+24, h->body_len);
                if (rlen == h->body_len) {
                    fprintf(stderr, "success receive client request body len=%d\n", rlen);
                    send_request(lp[i]->conn.port, lp[i]->sndbuf, 24+rlen); //send to server
                }
            }
        }
        sleep(1);
        break;
    }
    return (void *)5;
}

void *recv_thread(void *arg) {
    proxy_conn_t *conn = (proxy_conn_t *)arg;
    proxy_conn_info **lp = conn->conn_list;
    fprintf(stderr, "enter recv_thread\n");
    ssize_t len;
    while (1) {
        fprintf(stderr, "conn num_conn=%d proxy_fd=%d\n", conn->num_conn, conn->proxy_fd);
        for (int i=0; i<conn->num_conn; i++) {
            //fprintf(stderr, "conn info server=%s port=%d rank=%d sfd=%d\n",
            //        lp[i]->conn.server, lp[i]->conn.port, lp[i]->conn.rank, lp[i]->conn.sfd);
            len = recv_request(lp[i]->conn.sfd, lp[i]->rcvbuf, 24);  //recv from server
            if (24 == len) {
                binary_header_t *h = (binary_header_t *)lp[i]->rcvbuf;
                len = recv_request(lp[i]->conn.sfd, lp[i]->rcvbuf+24, h->body_len);
                if (len == h->body_len) {
                    fprintf(stderr, "success receive server response body len=%zd\n", len);
                    send_request(conn->proxy_fd, lp[i]->rcvbuf, 24+len); //send to client
                } else {
                    fprintf(stderr, "error recv_request body len=%zd\n", len);
                }

            } else {
                fprintf(stderr, "error recv_request header len=%zd\n", len);
            }
        }
        sleep(1);
        break;
    }
    return (void *)7;
}

void report_pthread_error(int rc, char *str, int tid) {
    if (rc != 0) {
        fprintf(stderr, "ERROR %s tid=%d rc=%s\n", str, tid, strerror(rc));
        exit(-1);
    } else {
        fprintf(stderr, "%s tid=%d %s\n", str, tid, strerror(rc));
    }
}

void alloc_thread(int num) {
    _workers = (pthread_t **)malloc(num * sizeof(pthread_t *));
    for (int i=0; i<num; i++) {
        _workers[i] = (pthread_t *)malloc(sizeof(pthread_t));
    }
}

void release_thread(pthread_t **t, int num) {
    for (int i=0; i<num; i++) {
        free(t[i]);
    }
    free(t);
}
 */

/* calloc() that terminates the process on failure. A zero count is allowed:
 * the result may then be NULL, which is fine because it is never indexed. */
static void *proxy_xcalloc(size_t count, size_t size) {
    void *p = calloc(count, size);
    if (NULL == p && 0 != count) {
        fprintf(stderr, "out of memory (%zu x %zu bytes)\n", count, size);
        exit(EXIT_FAILURE);
    }
    return p;
}

/*
 * Proxy entry point.
 *
 * Usage: <prog> <config.ini>
 *
 * Parses the ini file, looks up this host's proxy port with
 * find_port_by_server(), listens on it (client connections are accepted by
 * do_accept), wraps each backend server connection from pcs_list in a
 * bufferevent, and runs the libevent loop. The cleanup after
 * event_base_dispatch() only runs once that loop returns.
 */
int main(int argc, char **argv) {
    if (argc < 2) {
        fprintf(stderr, "usage: %s <config.ini>\n", argc > 0 ? argv[0] : "proxy");
        exit(EXIT_FAILURE);
    }
    parse_ini_file(argv[1]);

    /* gethostname() need not NUL-terminate a truncated name, so force it. */
    char hostname[256];
    if (0 != gethostname(hostname, sizeof hostname)) {
        perror("gethostname");
        exit(EXIT_FAILURE);
    }
    hostname[sizeof hostname - 1] = '\0';

    int proxy_port = find_port_by_server(hostname);
    if (proxy_port > 0) {
        fprintf(stderr, "get proxy_port %d <= %s\n", proxy_port, hostname);
    } else {
        fprintf(stderr, "cannot get proxy_port %d <= %s\n", proxy_port, hostname);
        exit(EXIT_FAILURE);
    }

    /* NOTE(review): assumes serversock() returns a negative value on failure. */
    int proxy_fd = serversock(proxy_port);
    if (proxy_fd < 0) {
        fprintf(stderr, "cannot create listening socket on port %d\n", proxy_port);
        exit(EXIT_FAILURE);
    }
    fprintf(stderr, "create listening proxy_fd=%d\n", proxy_fd);
    evutil_make_socket_nonblocking(proxy_fd);

    /* calloc so that unused slots are NULL and can be skipped at shutdown. */
    size_t n_client_slots =
            (size_t)s_config.num_clients * (size_t)s_config.num_conn_per_client;
    bev_s = proxy_xcalloc((size_t)s_config.num_servers, sizeof *bev_s);
    bev_c = proxy_xcalloc(n_client_slots, sizeof *bev_c);

    struct event_base *base = event_base_new();
    if (NULL == base) {
        fprintf(stderr, "event_base_new failed\n");
        exit(EXIT_FAILURE);
    }
    bf_map.num_entries = 0;
    bf_map.entries = proxy_xcalloc(n_client_slots, sizeof *bf_map.entries);

    struct event *event_read =
            event_new(base, proxy_fd, EV_READ | EV_PERSIST, do_accept, (void *)base);
    if (NULL == event_read || 0 != event_add(event_read, NULL)) {
        fprintf(stderr, "cannot register accept event on proxy_fd=%d\n", proxy_fd);
        exit(EXIT_FAILURE);
    }

    init_proxy_conn_server_list(&pcs_list, &s_config);
    pcc_list.pcc_list = proxy_xcalloc(n_client_slots, sizeof *pcc_list.pcc_list);
    fprintf(stderr, "finish init_proxy_conn_server_list\n");

    for (int i = 0; i < s_config.num_servers; i++) {
        fprintf(stderr, "try to init server bev=%d\n", i);
        evutil_make_socket_nonblocking(pcs_list.pcs_list[i]->conn.sfd);
        struct bufferevent *pbe = bufferevent_socket_new(
                base, pcs_list.pcs_list[i]->conn.sfd, BEV_OPT_CLOSE_ON_FREE);
        if (NULL == pbe) {
            fprintf(stderr, "cannot create bufferevent for server %d\n", i);
            exit(EXIT_FAILURE);
        }
        bufferevent_setcb(pbe, proxy_r_server_cb, NULL, socket_event_cb, (void *)&pcc_list);
        bufferevent_enable(pbe, EV_READ | EV_PERSIST);
        bev_s[i] = pbe;
    }
    fprintf(stderr, "finish create server bufferevent\n");

    event_base_dispatch(base);
    fprintf(stderr, "finish dispatch\n");

    release_proxy_conn_server_list(&pcs_list, s_config.num_servers);

    /* NOTE(review): a slot freed here must not already have been freed during
     * the loop. That is only safe if socket_event_cb clears the slot when it
     * frees a bufferevent. */
    for (int i = 0; i < s_config.num_servers; i++) {
        if (NULL != bev_s[i])
            bufferevent_free(bev_s[i]);
    }
    free(bev_s);
    bev_s = NULL;
    for (size_t i = 0; i < n_client_slots; i++) {
        if (NULL != bev_c[i])
            bufferevent_free(bev_c[i]);
    }
    free(bev_c);
    bev_c = NULL;

    /* The listener event must be released before its base. */
    event_free(event_read);
    evutil_closesocket(proxy_fd);
    event_base_free(base);
    fprintf(stderr, "finish free base\n");

    /* NOTE(review): only the arrays allocated here are freed; ownership of the
     * entries added by init_proxy_conn_client_entry() is not visible here. */
    free(bf_map.entries);
    bf_map.entries = NULL;
    free(pcc_list.pcc_list);
    pcc_list.pcc_list = NULL;
    return 0;
}

/*
 * Accept callback for the listening socket registered in main().
 *
 * Accepts one client connection on `efd` and makes it non-blocking. It then
 * wraps it in a bufferevent: the read callback is proxy_r_cli_cb with
 * &pcs_list, and the event callback is socket_event_cb. Finally it records the
 * connection in bf_map, bev_c and pcc_list under the next free index.
 * Connections beyond the capacity main() allocated
 * (num_clients * num_conn_per_client) are closed immediately.
 *
 * @param efd   listening socket that became readable
 * @param event libevent event flags (unused)
 * @param arg   struct event_base the new bufferevent is attached to
 */
void do_accept(evutil_socket_t efd, short event, void *arg) {
    struct event_base *base = (struct event_base *)arg;
    struct sockaddr_in sin;
    socklen_t slen = sizeof sin;  /* value-result: accept() reads the buffer size */
    (void)event;

    fprintf(stderr, "accept client listen_fd=%d\n", (int)efd);
    evutil_socket_t fd = accept(efd, (struct sockaddr *)&sin, &slen);
    if (fd < 0) {
        perror("do_accept() accept");
        return;
    }
    if (fd >= FD_SETSIZE) {
        /* accept() succeeded, so errno says nothing here; don't leak the fd. */
        fprintf(stderr, "fd=%d >= FD_SETSIZE, rejecting\n", (int)fd);
        evutil_closesocket(fd);
        return;
    }

    /* bev_c and bf_map.entries were sized in main() for this many clients. */
    size_t capacity =
            (size_t)s_config.num_clients * (size_t)s_config.num_conn_per_client;
    if ((size_t)bf_map.num_entries >= capacity) {
        fprintf(stderr, "do_accept() client table full (%zu), rejecting fd=%d\n",
                capacity, (int)fd);
        evutil_closesocket(fd);
        return;
    }

    /* bufferevents need non-blocking sockets, and on Linux accept() does not
     * inherit O_NONBLOCK from the listener. */
    if (0 != evutil_make_socket_nonblocking(fd)) {
        fprintf(stderr, "do_accept() cannot make fd=%d non-blocking\n", (int)fd);
        evutil_closesocket(fd);
        return;
    }

    struct bufferevent *bev = bufferevent_socket_new(base, fd, BEV_OPT_CLOSE_ON_FREE);
    if (NULL == bev) {
        fprintf(stderr, "do_accept() bufferevent_socket_new failed fd=%d\n", (int)fd);
        evutil_closesocket(fd);
        return;
    }
    bufferevent_setcb(bev, proxy_r_cli_cb, NULL, socket_event_cb, (void *)&pcs_list);
    bufferevent_enable(bev, EV_READ | EV_WRITE | EV_PERSIST);

    int idx = bf_map.num_entries;
    bf_map.entries[idx].bev = bev;
    bf_map.entries[idx].index = idx;
    bf_map.entries[idx].mfd = fd;
    bev_c[idx] = bev;
    fprintf(stderr, "do_accept() fd=%d bev_idx=%d\n\n", (int)fd, idx);
    bf_map.num_entries++;
    init_proxy_conn_client_entry(&pcc_list, fd, 0, NULL);
}

/*
void socket_read_cb(struct bufferevent *bev, void *arg) {
    proxy_conn_t *conn = (proxy_conn_t *)arg;
    proxy_conn_info **lp = conn->conn_list;

    //char msg[4096];
    ssize_t len = bufferevent_read(bev, lp[0]->sndbuf, lp[0]->sndbufsize);
    lp[0]->sndbuf[len] = '\0';
    if (len >= 24) {
        binary_header_t *h = (binary_header_t *)lp[0]->sndbuf;
        int body_len = ntohl(h->body_len);
        fprintf(stderr, "req h=24 body_len=%d len=%zd\n", body_len, len);
        if (len == body_len + 24) {
            fprintf(stderr, "success receive client request len=%zd\n", len);
            fprintf(stderr, "prepare to send msg => (sfd=%u)\n", lp[0]->conn.sfd);
            send_request(lp[0]->conn.sfd, lp[0]->sndbuf, (size_t)len); //send to server
        } else {
            exit(EXIT_FAILURE);
        }
    } else {
        exit(EXIT_FAILURE);
    }
    len = recv_request(lp[0]->conn.sfd, lp[0]->rcvbuf, lp[0]->rcvbufsize);
    lp[0]->rcvbuf[len] = '\0';
    if (len >= 24) {
        binary_header_t *h = (binary_header_t *)lp[0]->rcvbuf;
        int body_len = ntohl(h->body_len);
        fprintf(stderr, "resp h=24 body_len=%d len=%zd\n", body_len, len);
        if (len == body_len + 24) {
            fprintf(stderr, "success receive client request body len=%d\n", body_len);
            bufferevent_write(bev, lp[0]->rcvbuf, len);
        }
    }
}*/

/*
 * Route one complete client request `msg` (`len` bytes, NUL-terminated by the
 * caller) received on client socket `src_fd`. The destination is the backend
 * server whose port get_port_from_resp_opaque() extracts from the request.
 * src_fd is stamped into the request with fill_sfd_to_req_h() before
 * forwarding.
 *
 * GET/GETK requests are appended to that server's q_buf. Once two are queued
 * they are flushed as one MGET, with the header composed by
 * compose_binary_mget_req_h(). Every other opcode is written to the server
 * immediately.
 *
 * Requests for an unknown port or a closed server are dropped. A failed MGET
 * flush terminates the process.
 */
static void cli_cb_route_request(proxy_conn_server_list *pcs, char *msg, size_t len,
                                 int src_fd) {
    binary_header_t *h = (binary_header_t *)msg;

    /* NOTE(review): order kept on purpose. fill_sfd_to_req_h() appears to
     * overwrite the field get_port_from_resp_opaque() reads. */
    uint16_t dst_port = get_port_from_resp_opaque(msg);
    int dst_idx = -1;
    for (int i = 0; i < pcs->num_s_conn; i++) {
        if (dst_port == pcs->pcs_list[i]->conn.port) {
            dst_idx = i;
            break;
        }
    }
    fill_sfd_to_req_h(msg, src_fd);
    fprintf(stderr, "FILL src_sfd=%d dst_fd_idx=%d dst_port=%d\n", src_fd, dst_idx, dst_port);

    if (dst_idx < 0) {
        fprintf(stderr, "proxy_r_cli_cb() no server for port=%d, dropping request\n", dst_port);
        return;
    }
    if (NULL == bev_s[dst_idx]) {
        fprintf(stderr, "proxy_r_cli_cb() server idx=%d is closed, dropping request\n", dst_idx);
        return;
    }

    /* NOTE(review): CMD_GETK used to be tested twice; the duplicate may have
     * been meant as another GET variant (e.g. GETQ/GETKQ) - confirm. */
    if (CMD_GET == h->opcode || CMD_GETK == h->opcode) {
        /* Must be a pointer: a by-value copy loses the counter updates, so
         * the queue would never reach the flush threshold. */
        queue_buf *qb = &pcs->pcs_list[dst_idx]->q_buf;
        if (0 == qb->num_rqs) {
            assert(0 == qb->buf_offset && 0 == qb->buf_len);
            qb->buf_offset = 24;  /* leave room for the MGET header */
            qb->buf_len = 24;
        }
        /* NOTE(review): no capacity check on qb->buf - its size is not visible
         * here. With a flush at 2 requests of < 1024 bytes each, at most
         * 24 + 2 * 1023 bytes are used. */
        memcpy(qb->buf + qb->buf_offset, msg, len);
        qb->buf_len += len;
        qb->buf_offset += len;
        qb->num_rqs += 1;

        if (qb->num_rqs < 2 /* || timeout */) {
            fprintf(stderr, "bufferevent_write pending rqs=%d\n", qb->num_rqs);
            return;
        }

        compose_binary_mget_req_h(qb->buf, qb->num_rqs, qb->buf_len - 24);
        int ret = bufferevent_write(bev_s[dst_idx], qb->buf, qb->buf_len);
        if (0 != ret) {
            fprintf(stderr, "bufferevent_write error ret=%d\n", ret);
            exit(EXIT_FAILURE);
        }
        /* bufferevent_write copied the data, so the queue can be reused. */
        qb->buf_len = 0;
        qb->buf_offset = 0;
        qb->num_rqs = 0;
        fprintf(stderr, "bufferevent_write success => server=%s fd=%d\n",
                pcs->pcs_list[dst_idx]->conn.server,
                pcs->pcs_list[dst_idx]->conn.sfd);
    } else {
        int ret = bufferevent_write(bev_s[dst_idx], msg, len);
        fprintf(stderr, "simply transfer non-get rqs_len=%zu ret=%d\n", len, ret);
    }
}

/*
 * Read callback for client connections: proxy reads client input.
 *
 * @param bev client bufferevent with pending input
 * @param arg the proxy_conn_server_list (&pcs_list, as registered in do_accept)
 *
 * TCP can deliver part of a request, or several requests in one read. This
 * callback therefore consumes only complete 24-byte-header + body frames and
 * leaves any partial frame buffered for the next call. Frames that do not fit
 * the local buffer (1023 bytes plus NUL) are drained and dropped.
 */
void proxy_r_cli_cb(struct bufferevent *bev, void *arg) {
    proxy_conn_server_list *pcs = (proxy_conn_server_list *)arg;
    struct evbuffer *input = bufferevent_get_input(bev);
    int src_fd = (int)bufferevent_getfd(bev);
    char msg[1024];

    for (;;) {
        size_t avail = evbuffer_get_length(input);
        if (avail < 24)
            break;  /* header not complete yet */

        evbuffer_copyout(input, msg, 24);
        binary_header_t *h = (binary_header_t *)msg;
        size_t body_len = ntohl(h->body_len);
        if (avail - 24 < body_len)
            break;  /* body not complete yet */

        size_t frame_len = 24 + body_len;
        if (frame_len > sizeof msg - 1) {  /* keep one byte for the NUL */
            fprintf(stderr, "proxy_r_cli_cb() dropping oversized request len=%zu\n", frame_len);
            evbuffer_drain(input, frame_len);
            continue;
        }

        size_t len = bufferevent_read(bev, msg, frame_len);
        msg[len] = '\0';
        fprintf(stderr, "proxy_r_cli_cb() client_req h=%d body_len=%zu len=%zu\n",
                24, body_len, len);
        cli_cb_route_request(pcs, msg, len, src_fd);
    }
}

void socket_event_cb(struct bufferevent *bev, short events, void *arg) {
    evutil_socket_t fd = bufferevent_getfd(bev);
    fprintf(stderr, "socket_event_cb fd=%u\n", fd);
    if (events & BEV_EVENT_EOF) {
        fprintf(stderr, "connection closed\n");
    } else if (events & BEV_EVENT_ERROR) {
        fprintf(stderr, "some other error\n");
    } else if (events & BEV_EVENT_TIMEOUT) {
        fprintf(stderr, "time out error\n");
    }
    bufferevent_free(bev);
}

/* Index in pcc->pcc_list of the client whose sfd equals `sfd`, or -1. */
static int srv_cb_find_client(const proxy_conn_client_list *pcc, int sfd) {
    for (int i = 0; i < pcc->num_c_conn; i++) {
        if (sfd == pcc->pcc_list[i]->sfd)
            return i;
    }
    return -1;
}

/*
 * Split an MGET response `msg` (`len` bytes, outer header included) into the
 * `n_sub` sub-responses that follow its 24-byte header. Each sub-response is
 * written to the client named by the sfd in its opaque field, as decoded by
 * get_sfd_from_opaque().
 *
 * Sub-responses with no matching or open client are skipped, but parsing still
 * advances past them. A truncated sub-response stops the split. A write
 * failure terminates the process.
 */
static void srv_cb_split_mget(proxy_conn_client_list *pcc, char *msg, size_t len, int n_sub) {
    size_t off = 24;  /* invariant: off <= len */
    for (int k = 0; k < n_sub; k++) {
        if (len - off < 24) {
            fprintf(stderr, "proxy_r_server_cb() truncated mget header at %d/%d\n", k, n_sub);
            return;
        }
        binary_header_t *sub = (binary_header_t *)(msg + off);
        size_t sub_body = ntohl(sub->body_len);
        if (len - off - 24 < sub_body) {
            fprintf(stderr, "proxy_r_server_cb() truncated mget body at %d/%d\n", k, n_sub);
            return;
        }
        size_t cp_len = 24 + sub_body;

        int dst_sfd = get_sfd_from_opaque(msg + off);
        int dst_idx = srv_cb_find_client(pcc, dst_sfd);  /* fresh per sub-response */
        fprintf(stderr, "proxy_r_server_cb() dst_idx=%d dst_sfd=%d\n", dst_idx, dst_sfd);
        /* NOTE(review): assumes pcc_list and bev_c share indices (both are
         * filled in accept order by do_accept) - confirm. */
        if (dst_idx >= 0 && NULL != bev_c[dst_idx]) {
            int ret = bufferevent_write(bev_c[dst_idx], msg + off, cp_len);
            if (0 != ret) {
                fprintf(stderr, "bufferevent_write error ret=%d\n", ret);
                exit(EXIT_FAILURE);
            }
            fprintf(stderr, "bufferevent_write success => server=%s fd=%d\n",
                    pcc->pcc_list[dst_idx]->server, pcc->pcc_list[dst_idx]->sfd);
        }
        off += cp_len;  /* advance even if this sub-response was not delivered */
    }
}

/*
 * Read callback for backend server connections: proxy reads server input.
 *
 * @param bev server bufferevent with pending input
 * @param arg the proxy_conn_client_list (&pcc_list, as registered in main)
 *
 * Only complete 24-byte-header + body frames are consumed; a partial frame
 * stays buffered until more data arrives. MGET responses are split and fanned
 * out to their clients, and the sub-response count is taken from the outer
 * header's extra_len. Any other response goes whole to the client whose sfd is
 * in its opaque field.
 *
 * Responses for unknown or closed clients are dropped. Frames larger than the
 * local buffer (1023 bytes plus NUL) are drained and dropped.
 */
void proxy_r_server_cb(struct bufferevent *bev, void *arg) {
    proxy_conn_client_list *pcc = (proxy_conn_client_list *)arg;
    struct evbuffer *input = bufferevent_get_input(bev);
    char msg[1024];

    for (;;) {
        size_t avail = evbuffer_get_length(input);
        if (avail < 24)
            break;  /* header not complete yet */

        evbuffer_copyout(input, msg, 24);
        binary_header_t *h = (binary_header_t *)msg;
        size_t body_len = ntohl(h->body_len);
        if (avail - 24 < body_len)
            break;  /* body not complete yet */

        size_t frame_len = 24 + body_len;
        if (frame_len > sizeof msg - 1) {  /* keep one byte for the NUL */
            fprintf(stderr, "proxy_r_server_cb() dropping oversized response len=%zu\n",
                    frame_len);
            evbuffer_drain(input, frame_len);
            continue;
        }

        size_t len = bufferevent_read(bev, msg, frame_len);
        msg[len] = '\0';
        fprintf(stderr, "proxy_r_server_cb() s_resp_h=%d body_len=%zu len=%zu batch_sz=%d\n",
                24, body_len, len, (int)h->extra_len);

        if (CMD_MGET == h->opcode) {
            /* Read the count from the outer header once; sub-headers carry
             * their own, unrelated extra_len. */
            srv_cb_split_mget(pcc, msg, len, (int)h->extra_len);
            continue;
        }

        int dst_sfd = get_sfd_from_opaque(msg);
        int dst_idx = srv_cb_find_client(pcc, dst_sfd);
        if (dst_idx < 0 || NULL == bev_c[dst_idx]) {
            fprintf(stderr, "proxy_r_server_cb() no client for dst_sfd=%d, dropping resp_len=%zu\n",
                    dst_sfd, len);
            continue;
        }
        fprintf(stderr, "dst_sfd=%d dst_idx=%d server=%s\n", dst_sfd, dst_idx,
                pcc->pcc_list[dst_idx]->server);
        int ret = bufferevent_write(bev_c[dst_idx], msg, len);
        fprintf(stderr, "simply transfer non-mget resp_len=%zu ret=%d\n", len, ret);
    }
}


